{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy\n",
    "import scipy.special"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "class NeuralNetwork:\n",
    "    def __init__(self, inputnodes, hiddennodes, outputnodes, learningRate):\n",
    "        #初始化网络，设置输入层、中间层和输出层的节点数\n",
    "        self.input_nodes = inputnodes\n",
    "        self.hidden_nodes = hiddennodes\n",
    "        self.output_nodes = outputnodes\n",
    "        self.lr = learningRate\n",
    "        \n",
    "        #初始化权重矩阵，一个是wih，一个是who\n",
    "        self.wih = numpy.random.rand(self.hidden_nodes, self.input_nodes) - 0.5\n",
    "        self.who = numpy.random.rand(self.output_nodes, self.hidden_nodes) - 0.5\n",
    "        pass\n",
    "    \n",
    "    def fit(self, inputs_list, targets_list):\n",
    "        #1.计算输出结果\n",
    "        #把inputs_list,targets_list转成二维矩阵，.T代表转置\n",
    "        inputs = numpy.array(inputs_list, ndmin = 2).T\n",
    "        targets = numpy.array(targets_list, ndmin = 2).T\n",
    "        #计算信号经过输入层后产生的信号量\n",
    "        hidden_inputs = numpy.dot(self.wih, inputs)\n",
    "        #中间层神经元对输入的信号做激活函数后得到输出信号\n",
    "        #sigmoid对应于scipy.special.expit(x)\n",
    "        sigmoid = lambda x:scipy.special.expit(x)\n",
    "        hidden_outputs = sigmoid(hidden_inputs)\n",
    "        #输出层接受中间层的信号量\n",
    "        final_inputs = numpy.dot(self.who, hidden_outputs)\n",
    "        #输出层输出信号\n",
    "        final_outputs = sigmoid(final_inputs)\n",
    "        \n",
    "        #2.计算误差\n",
    "        output_errors = targets - final_outputs\n",
    "        #3.反向误差传播\n",
    "        hidden_errors = numpy.dot(self.who.T, output_errors)\n",
    "        #4. 根据误差计算链路权重的更新量，把更新加到原来的链路权重上\n",
    "        self.who += self.lr * numpy.dot((output_errors * final_outputs * (1 - final_outputs)),\n",
    "                                       numpy.transpose(hidden_outputs))\n",
    "        self.wih += self.lr * numpy.dot((hidden_errors * hidden_outputs * (1 - hidden_outputs)),\n",
    "                                       numpy.transpose(inputs))\n",
    "        pass\n",
    "    \n",
    "    def evaluate(self, inputs):\n",
    "        #输入新数据，网络给出对新数据的判断结果\n",
    "        hidden_inputs = numpy.dot(self.wih, inputs)\n",
    "        #sigmoid对应于scipy.special.expit(x)\n",
    "        sigmoid = lambda x:scipy.special.expit(x)\n",
    "        hidden_outputs = sigmoid(hidden_inputs)\n",
    "        #计算输出层神经元接收到的信号量\n",
    "        final_inputs = numpy.dot(self.who, hidden_outputs);\n",
    "        #调用sigmoid计算最外层神经元的输出信号\n",
    "        final_outputs = sigmoid(final_inputs)\n",
    "        return final_outputs\n",
    "        pass"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Using TensorFlow backend.\n",
      "/home/lancer/anaconda3/envs/keras1/lib/python3.6/site-packages/tensorflow/python/framework/dtypes.py:516: FutureWarning: Passing (type, 1) or '1type' as a synonym of type is deprecated; in a future version of numpy, it will be understood as (type, (1,)) / '(1,)type'.\n",
      "  _np_qint8 = np.dtype([(\"qint8\", np.int8, 1)])\n",
      "/home/lancer/anaconda3/envs/keras1/lib/python3.6/site-packages/tensorflow/python/framework/dtypes.py:517: FutureWarning: Passing (type, 1) or '1type' as a synonym of type is deprecated; in a future version of numpy, it will be understood as (type, (1,)) / '(1,)type'.\n",
      "  _np_quint8 = np.dtype([(\"quint8\", np.uint8, 1)])\n",
      "/home/lancer/anaconda3/envs/keras1/lib/python3.6/site-packages/tensorflow/python/framework/dtypes.py:518: FutureWarning: Passing (type, 1) or '1type' as a synonym of type is deprecated; in a future version of numpy, it will be understood as (type, (1,)) / '(1,)type'.\n",
      "  _np_qint16 = np.dtype([(\"qint16\", np.int16, 1)])\n",
      "/home/lancer/anaconda3/envs/keras1/lib/python3.6/site-packages/tensorflow/python/framework/dtypes.py:519: FutureWarning: Passing (type, 1) or '1type' as a synonym of type is deprecated; in a future version of numpy, it will be understood as (type, (1,)) / '(1,)type'.\n",
      "  _np_quint16 = np.dtype([(\"quint16\", np.uint16, 1)])\n",
      "/home/lancer/anaconda3/envs/keras1/lib/python3.6/site-packages/tensorflow/python/framework/dtypes.py:520: FutureWarning: Passing (type, 1) or '1type' as a synonym of type is deprecated; in a future version of numpy, it will be understood as (type, (1,)) / '(1,)type'.\n",
      "  _np_qint32 = np.dtype([(\"qint32\", np.int32, 1)])\n",
      "/home/lancer/anaconda3/envs/keras1/lib/python3.6/site-packages/tensorflow/python/framework/dtypes.py:525: FutureWarning: Passing (type, 1) or '1type' as a synonym of type is deprecated; in a future version of numpy, it will be understood as (type, (1,)) / '(1,)type'.\n",
      "  np_resource = np.dtype([(\"resource\", np.ubyte, 1)])\n",
      "/home/lancer/anaconda3/envs/keras1/lib/python3.6/site-packages/tensorboard/compat/tensorflow_stub/dtypes.py:541: FutureWarning: Passing (type, 1) or '1type' as a synonym of type is deprecated; in a future version of numpy, it will be understood as (type, (1,)) / '(1,)type'.\n",
      "  _np_qint8 = np.dtype([(\"qint8\", np.int8, 1)])\n",
      "/home/lancer/anaconda3/envs/keras1/lib/python3.6/site-packages/tensorboard/compat/tensorflow_stub/dtypes.py:542: FutureWarning: Passing (type, 1) or '1type' as a synonym of type is deprecated; in a future version of numpy, it will be understood as (type, (1,)) / '(1,)type'.\n",
      "  _np_quint8 = np.dtype([(\"quint8\", np.uint8, 1)])\n",
      "/home/lancer/anaconda3/envs/keras1/lib/python3.6/site-packages/tensorboard/compat/tensorflow_stub/dtypes.py:543: FutureWarning: Passing (type, 1) or '1type' as a synonym of type is deprecated; in a future version of numpy, it will be understood as (type, (1,)) / '(1,)type'.\n",
      "  _np_qint16 = np.dtype([(\"qint16\", np.int16, 1)])\n",
      "/home/lancer/anaconda3/envs/keras1/lib/python3.6/site-packages/tensorboard/compat/tensorflow_stub/dtypes.py:544: FutureWarning: Passing (type, 1) or '1type' as a synonym of type is deprecated; in a future version of numpy, it will be understood as (type, (1,)) / '(1,)type'.\n",
      "  _np_quint16 = np.dtype([(\"quint16\", np.uint16, 1)])\n",
      "/home/lancer/anaconda3/envs/keras1/lib/python3.6/site-packages/tensorboard/compat/tensorflow_stub/dtypes.py:545: FutureWarning: Passing (type, 1) or '1type' as a synonym of type is deprecated; in a future version of numpy, it will be understood as (type, (1,)) / '(1,)type'.\n",
      "  _np_qint32 = np.dtype([(\"qint32\", np.int32, 1)])\n",
      "/home/lancer/anaconda3/envs/keras1/lib/python3.6/site-packages/tensorboard/compat/tensorflow_stub/dtypes.py:550: FutureWarning: Passing (type, 1) or '1type' as a synonym of type is deprecated; in a future version of numpy, it will be understood as (type, (1,)) / '(1,)type'.\n",
      "  np_resource = np.dtype([(\"resource\", np.ubyte, 1)])\n"
     ]
    }
   ],
   "source": [
    "# Load the MNIST digit images and labels (train and test splits).\n",
    "from keras.datasets import mnist\n",
    "from keras.utils import to_categorical\n",
    "\n",
    "(train_images, train_labels), (test_images, test_labels) = mnist.load_data()\n",
    "\n",
    "# Flatten every 28x28 image into one row of 784 values and rescale the pixels\n",
    "# from 0-255 to the interval [0, 1]. The -1 lets numpy infer the number of\n",
    "# samples instead of hard-coding the size of each split.\n",
    "train_images = train_images.reshape(-1, 28 * 28).astype('float32') / 255\n",
    "test_images = test_images.reshape(-1, 28 * 28).astype('float32') / 255\n",
    "\n",
    "# Convert the labels to the vector format produced by to_categorical.\n",
    "train_labels = to_categorical(train_labels)\n",
    "test_labels = to_categorical(test_labels)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Network layout: 784 input nodes (one per pixel of a flattened 28x28 image),\n",
    "# 100 hidden nodes and 10 output nodes.\n",
    "input_nodes = 784\n",
    "hidden_nodes = 100\n",
    "output_nodes = 10\n",
    "# Learning rate used for every weight update.\n",
    "learning_rate = 0.3\n",
    "\n",
    "# The initial weights are drawn from numpy's global random generator, so seed\n",
    "# it first; otherwise every run starts from different weights and reports a\n",
    "# different accuracy.\n",
    "random_seed = 42\n",
    "numpy.random.seed(random_seed)\n",
    "n = NeuralNetwork(input_nodes, hidden_nodes, output_nodes, learning_rate)\n",
    "\n",
    "# One pass over the training set, one sample (image plus its label vector) at a time.\n",
    "for train_image, train_label in zip(train_images, train_labels):\n",
    "    n.fit(train_image, train_label)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "performace =  0.9439\n"
     ]
    }
   ],
   "source": [
    "# Score the trained network on the test set: 1 for every image whose predicted\n",
    "# class (index of the largest output) matches the label, 0 otherwise.\n",
    "scores = [\n",
    "    int(numpy.argmax(n.evaluate(test_image)) == numpy.argmax(test_label))\n",
    "    for test_image, test_label in zip(test_images, test_labels)\n",
    "]\n",
    "\n",
    "# Accuracy is the fraction of correct predictions.\n",
    "scores_array = numpy.asarray(scores)\n",
    "print(\"performance = \", scores_array.sum() / scores_array.size)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<matplotlib.image.AxesImage at 0x7fb815ab9320>"
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    },
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAPsAAAD4CAYAAAAq5pAIAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjMuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8vihELAAAACXBIWXMAAAsTAAALEwEAmpwYAAANL0lEQVR4nO3dXahd9ZnH8d9vYqPBFs0xRw1p9MQieHRwknKIQaU4lAm+XMRcODRKyaBMeqHSYi98mYtGQQzDtDUXQyGdxKTasRTamAgyNoSKKWjwKGc0meAcjWea1JjsEDBWhGryzMVZmTnGs9fZ7rX2S/J8P3DYe69nvTxs8svae//X3n9HhACc/f6q1w0A6A7CDiRB2IEkCDuQBGEHkjinmwebN29eDA0NdfOQQCoTExM6evSop6tVCrvtmyWtlzRL0r9FxLqy9YeGhjQ6OlrlkABKjIyMNK21/TLe9ixJ/yrpFklXS1pl++p29wegs6q8Z18q6Z2I2B8Rf5H0K0kr6mkLQN2qhH2BpANTHh8sln2O7TW2R22PNhqNCocDUEWVsE/3IcAXrr2NiA0RMRIRI4ODgxUOB6CKKmE/KGnhlMdfl/R+tXYAdEqVsL8m6Urbi2zPlvQdSdvraQtA3doeeouIz2zfJ+lFTQ69bYqIvbV1BqBWlcbZI+IFSS/U1AuADuJyWSAJwg4kQdiBJAg7kARhB5Ig7EAShB1IgrADSRB2IAnCDiRB2IEkCDuQBGEHkiDsQBKEHUiCsANJEHYgCcIOJEHYgSQIO5AEYQeSIOxAEoQdSIKwA0kQdiAJwg4kQdiBJAg7kARhB5Ig7EASlaZstj0h6SNJJyR9FhEjdTQFoH6Vwl7424g4WsN+AHQQL+OBJKqGPST9zvbrttdMt4LtNbZHbY82Go2KhwPQrqphvyEivinpFkn32v7W6StExIaIGImIkcHBwYqHA9CuSmGPiPeL2yOStkpaWkdTAOrXdthtn2/7a6fuS1ouaU9djQGoV5VP4y+RtNX2qf38e0T8Ry1dAahd22GPiP2S/qbGXgB0EENvQBKEHUiCsANJEHYgCcIOJFHHF2FSePXVV5vW1q9fX7rtggULSutz5swpra9evbq0PjAw0FYNuXBmB5Ig7EAShB1IgrADSRB2IAnCDiRB2IEkGGdvUdlY9/j4eEeP/fjjj5fWL7jggqa1ZcuW1d3OGWNoaKhp7eGHHy7d9rLLLqu5m97jzA4kQdiBJAg7kARhB5Ig7EAShB1IgrADSTDO3qLnnnuuaW1sbKx022uuuaa0vnfv3tL67t27S+vbtm1rWnvxxRdLt120aFFp/b333iutV3HOOeX//ObPn19aP3DgQNvHLhuDl6QHH3yw7X33K87sQBKEHUiCsANJEHYgCcIOJEHYgSQIO5AE4+wtGh4ebqvWimuvvba0vmrVqtL6unXrmtYmJiZKt51pnH3//v2l9Spmz55dWp9pnH2m3huNRtPaVVddVbrt2WjGM7vtTbaP2N4zZdmA7R22x4vbuZ1tE0BVrbyM3yzp5tOWPSRpZ0RcKWln8RhAH5sx7BHxsqRjpy1eIWlLcX+LpNvrbQtA3dr9gO6SiDgkScXtxc1WtL3G9qjt0bL3UAA6q+OfxkfEhogYiYiRwcHBTh8OQBPthv2w7fmSVNweqa8lAJ3Qbti3Szr128qrJTX/jiWAvjDjOLvtZyXdJGme7YOSfiRpnaRf275H0h8l3dHJJlHuvPPOa1qrOp5c9RqCKmb6Hv/Ro0dL69ddd13T2vLly9vq6Uw2Y9gjotkVHd+uuRcAHcTlskAShB1IgrADSRB2IAnCDiTBV1zRMx9//HFpfeXKlaX1kydPltaffPLJprU5c+aUbns24swOJEHYgSQIO5AEYQeSIOxAEoQdSIKwA0kwzo6e2bx5c2n9gw8+KK1fdNFFpfXLL7/8y7Z0VuPMDiRB2IEkCDuQBGEHkiDsQBKEHUiC
sANJMM6Ojnr33Xeb1h544IFK+37llVdK65deemml/Z9tOLMDSRB2IAnCDiRB2IEkCDuQBGEHkiDsQBKMs6Ojnn/++aa1Tz/9tHTbO+4onwn8iiuuaKunrGY8s9veZPuI7T1Tlq21/SfbY8XfrZ1tE0BVrbyM3yzp5mmW/zQiFhd/L9TbFoC6zRj2iHhZ0rEu9AKgg6p8QHef7TeLl/lzm61ke43tUdujjUajwuEAVNFu2H8m6RuSFks6JOnHzVaMiA0RMRIRI4ODg20eDkBVbYU9Ig5HxImIOCnp55KW1tsWgLq1FXbb86c8XClpT7N1AfSHGcfZbT8r6SZJ82wflPQjSTfZXiwpJE1I+l7nWkQ/m2msfOvWrU1r5557bum2TzzxRGl91qxZpXV83oxhj4hV0yze2IFeAHQQl8sCSRB2IAnCDiRB2IEkCDuQBF9xRSUbN5YPzOzatatp7c477yzdlq+w1oszO5AEYQeSIOxAEoQdSIKwA0kQdiAJwg4kwTg7So2NjZXW77///tL6hRde2LT22GOPtdER2sWZHUiCsANJEHYgCcIOJEHYgSQIO5AEYQeSYJw9uU8++aS0vmrVdD8u/P9OnDhRWr/rrrua1vi+endxZgeSIOxAEoQdSIKwA0kQdiAJwg4kQdiBJBhnP8udPHmytH7bbbeV1t9+++3S+vDwcGn90UcfLa2je2Y8s9teaPv3tvfZ3mv7+8XyAds7bI8Xt3M73y6AdrXyMv4zST+MiGFJyyTda/tqSQ9J2hkRV0raWTwG0KdmDHtEHIqIN4r7H0naJ2mBpBWSthSrbZF0e4d6BFCDL/UBne0hSUsk7ZZ0SUQckib/Q5B0cZNt1tgetT3aaDQqtgugXS2H3fZXJf1G0g8i4nir20XEhogYiYiRwcHBdnoEUIOWwm77K5oM+i8j4rfF4sO25xf1+ZKOdKZFAHWYcejNtiVtlLQvIn4ypbRd0mpJ64rbbR3pEJUcO3astP7SSy9V2v/TTz9dWh8YGKi0f9SnlXH2GyR9V9JbtseKZY9oMuS/tn2PpD9KuqMjHQKoxYxhj4g/SHKT8rfrbQdAp3C5LJAEYQeSIOxAEoQdSIKwA0nwFdezwIcffti0tmzZskr7fuaZZ0rrS5YsqbR/dA9ndiAJwg4kQdiBJAg7kARhB5Ig7EAShB1IgnH2s8BTTz3VtLZ///5K+77xxhtL65M/d4AzAWd2IAnCDiRB2IEkCDuQBGEHkiDsQBKEHUiCcfYzwPj4eGl97dq13WkEZzTO7EAShB1IgrADSRB2IAnCDiRB2IEkCDuQRCvzsy+U9AtJl0o6KWlDRKy3vVbSP0pqFKs+EhEvdKrRzHbt2lVaP378eNv7Hh4eLq3PmTOn7X2jv7RyUc1nkn4YEW/Y/pqk123vKGo/jYh/6Vx7AOrSyvzshyQdKu5/ZHufpAWdbgxAvb7Ue3bbQ5KWSNpdLLrP9pu2N9me22SbNbZHbY82Go3pVgHQBS2H3fZXJf1G0g8i4rikn0n6hqTFmjzz/3i67SJiQ0SMRMTI4OBg9Y4BtKWlsNv+iiaD/suI+K0kRcThiDgREScl/VzS0s61CaCqGcPuyZ8P3ShpX0T8ZMry+VNWWylpT/3tAahLK5/G3yDpu5Lesj1WLHtE0irbiyWFpAlJ3+tAf6jo+uuvL63v2LGjtM7Q29mjlU/j/yBpuh8HZ0wdOINwBR2QBGEHkiDsQBKEHUiCsANJEHYgCX5K+gxw9913V6oDEmd2IA3CDiRB2IEkCDuQBGEHkiDsQBKEHUjCEdG9g9kNSf8zZdE8SUe71sCX06+99WtfEr21q87eLo+IaX//rath/8LB7dGIGOlZAyX6tbd+7Uuit3Z1qzdexgNJEHYgiV6HfUOPj1+mX3vr174kemtXV3rr6Xt2AN3T6zM7gC4h7EASPQm77Zttv237HdsP9aKHZmxP2H7L9pjt0R73ssn2Edt7piwbsL3D9nhxO+0cez3q
ba3tPxXP3ZjtW3vU20Lbv7e9z/Ze298vlvf0uSvpqyvPW9ffs9ueJem/Jf2dpIOSXpO0KiL+q6uNNGF7QtJIRPT8Agzb35L0Z0m/iIi/Lpb9s6RjEbGu+I9ybkQ82Ce9rZX0515P413MVjR/6jTjkm6X9A/q4XNX0tffqwvPWy/O7EslvRMR+yPiL5J+JWlFD/roexHxsqRjpy1eIWlLcX+LJv+xdF2T3vpCRByKiDeK+x9JOjXNeE+fu5K+uqIXYV8g6cCUxwfVX/O9h6Tf2X7d9ppeNzONSyLikDT5j0fSxT3u53QzTuPdTadNM943z107059X1YuwTzeVVD+N/90QEd+UdIuke4uXq2hNS9N4d8s004z3hXanP6+qF2E/KGnhlMdfl/R+D/qYVkS8X9wekbRV/TcV9eFTM+gWt0d63M//6adpvKebZlx98Nz1cvrzXoT9NUlX2l5ke7ak70ja3oM+vsD2+cUHJ7J9vqTl6r+pqLdLWl3cXy1pWw97+Zx+mca72TTj6vFz1/PpzyOi63+SbtXkJ/LvSvqnXvTQpK8rJP1n8be3171JelaTL+s+1eQronskXSRpp6Tx4nagj3p7WtJbkt7UZLDm96i3GzX51vBNSWPF3629fu5K+urK88blskASXEEHJEHYgSQIO5AEYQeSIOxAEoQdSIKwA0n8Lx5q4VTxgWLnAAAAAElFTkSuQmCC\n",
      "text/plain": [
       "<Figure size 432x288 with 1 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "#从测试数据集中抽取一张图片\n",
    "import matplotlib.pyplot as py\n",
    "(train_images, train_labels), (test_images, test_labels) = mnist.load_data()\n",
    "image = test_images[0]\n",
    "py.imshow(image, cmap = 'Greys', interpolation = 'None')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[1.25044647e-04 9.11956237e-06 6.57984316e-05 1.08404111e-05\n",
      " 4.04355599e-08 1.58195257e-06 2.63088420e-07 9.99913022e-01\n",
      " 1.67263818e-07 3.15898623e-05]\n",
      "number of image is:  7\n"
     ]
    }
   ],
   "source": [
    "#检验网络对图片的判断是否正确\n",
    "test_images = test_images.reshape(10000, 28*28)\n",
    "test_images = test_images.astype('float32') / 255\n",
    "result = n.evaluate(test_images[0])\n",
    "print(result)\n",
    "print(\"number of image is: \", numpy.argmax(result))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "7"
      ]
     },
     "execution_count": 8,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "a = [1.79370211e-04,1.72383640e-06,6.41083102e-03,8.71855928e-05,\n",
    " 2.27669657e-07,7.90085970e-06,2.19789172e-08,9.99642792e-01,\n",
    " 5.71201258e-06,1.09405446e-06]\n",
    "\n",
    "numpy.argmax(a)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
